{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e410b6c4",
   "metadata": {},
   "source": [
    "# Build a Qusetion Answering Engine in Minutes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bed6f24",
   "metadata": {},
   "source": [
    "This notebook illustrates how to build a question answering engine from scratch using [Milvus](https://milvus.io/) and [Towhee](https://towhee.io/). Milvus is the most advanced open-source vector database built for AI applications and supports nearest neighbor embedding search across tens of millions of entries, and Towhee is a framework that provides ETL for unstructured data using SoTA machine learning models.\n",
    "\n",
    "We will go through question answering procedures and evaluate performance. Moreover, we managed to make the core functionality as simple as almost 10 lines of code with Towhee, so that you can start hacking your own question answering engine."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4883e577",
   "metadata": {},
   "source": [
    "## Preparations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49110b91",
   "metadata": {},
   "source": [
    "### Install Dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0117995a",
   "metadata": {},
   "source": [
    "First we need to install dependencies such as towhee, towhee.models and gradio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "c9ba3850",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "! python -m pip install -q towhee towhee.models gradio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90db0c5",
   "metadata": {},
   "source": [
    "### Prepare the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1eceb58",
   "metadata": {},
   "source": [
    "There is a subset of the  [InsuranceQA Corpus](https://github.com/shuzi/insuranceQA)  (1000 pairs of questions and answers) used in this demo, everyone can download on [Github](https://github.com/towhee-io/examples/releases/download/data/question_answer.csv)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d1436a9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100  595k  100  595k    0     0   286k      0  0:00:02  0:00:02 --:--:--  666k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/question_answer.csv -O"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4abdc0a",
   "metadata": {},
   "source": [
    "**question_answer.csv**: a file containing question and the answer.\n",
    "\n",
    "Let's take a quick look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d652efea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>Is  Disability  Insurance  Required  By  Law?</td>\n",
       "      <td>Not generally. There are five states that requ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>Can  Creditors  Take  Life  Insurance  After  ...</td>\n",
       "      <td>If the person who passed away was the one with...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Does  Travelers  Insurance  Have  Renters  Ins...</td>\n",
       "      <td>One of the insurance carriers I represent is T...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>Can  I  Drive  A  New  Car  Home  Without  Ins...</td>\n",
       "      <td>Most auto dealers will not let you drive the c...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>Is  The  Cash  Surrender  Value  Of  Life  Ins...</td>\n",
       "      <td>Cash surrender value comes only with Whole Lif...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                           question  \\\n",
       "0   0      Is  Disability  Insurance  Required  By  Law?   \n",
       "1   1  Can  Creditors  Take  Life  Insurance  After  ...   \n",
       "2   2  Does  Travelers  Insurance  Have  Renters  Ins...   \n",
       "3   3  Can  I  Drive  A  New  Car  Home  Without  Ins...   \n",
       "4   4  Is  The  Cash  Surrender  Value  Of  Life  Ins...   \n",
       "\n",
       "                                              answer  \n",
       "0  Not generally. There are five states that requ...  \n",
       "1  If the person who passed away was the one with...  \n",
       "2  One of the insurance carriers I represent is T...  \n",
       "3  Most auto dealers will not let you drive the c...  \n",
       "4  Cash surrender value comes only with Whole Lif...  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('question_answer.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "309bfb43",
   "metadata": {},
   "source": [
    "To use the dataset to get answers, let's first define the dictionary:\n",
    "\n",
    "- `id_answer`: a dictionary of id and corresponding answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d98b309",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_answer = df.set_index('id')['answer'].to_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c5a0858",
   "metadata": {},
   "source": [
    "### Create Milvus Collection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efb06a01",
   "metadata": {},
   "source": [
    "Before getting started, please make sure that you have started a [Milvus service](https://milvus.io/docs/install_standalone-docker.md). This notebook uses [milvus 2.2.10](https://milvus.io/docs/v2.2.x/install_standalone-docker.md) and [pymilvus 2.2.11](https://milvus.io/docs/release_notes.md#2210)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8048bf6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m pip install -q pymilvus==2.2.11"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87ba2b23",
   "metadata": {},
   "source": [
    "Next to define the function `create_milvus_collection` to create collection in Milvus that uses the [L2 distance metric](https://milvus.io/docs/metric.md#Euclidean-distance-L2) and an [IVF_FLAT index](https://milvus.io/docs/index.md#IVF_FLAT)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "22c19982",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
    "\n",
    "connections.connect(host='127.0.0.1', port='19530')\n",
    "\n",
    "def create_milvus_collection(collection_name, dim):\n",
    "    if utility.has_collection(collection_name):\n",
    "        utility.drop_collection(collection_name)\n",
    "    \n",
    "    fields = [\n",
    "    FieldSchema(name='id', dtype=DataType.VARCHAR, descrition='ids', max_length=500, is_primary=True, auto_id=False),\n",
    "    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim)\n",
    "    ]\n",
    "    schema = CollectionSchema(fields=fields, description='reverse image search')\n",
    "    collection = Collection(name=collection_name, schema=schema)\n",
    "\n",
    "    # create IVF_FLAT index for collection.\n",
    "    index_params = {\n",
    "        'metric_type':'L2',\n",
    "        'index_type':\"IVF_FLAT\",\n",
    "        'params':{\"nlist\":2048}\n",
    "    }\n",
    "    collection.create_index(field_name=\"embedding\", index_params=index_params)\n",
    "    return collection\n",
    "\n",
    "collection = create_milvus_collection('question_answer', 768)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9724ba28",
   "metadata": {},
   "source": [
    "## Question Answering Engine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01e1ac7e",
   "metadata": {},
   "source": [
    "In this section, we will show how to build our question answering engine using Milvus and Towhee. The basic idea behind question answering is to use Towhee to generate embedding from the question dataset and compare the input question with the embedding stored in Milvus.\n",
    "\n",
    "[Towhee](https://towhee.io/) is a machine learning framework that allows the creation of data processing pipelines, and it also provides predefined operators for implementing insert and query operations in Milvus.\n",
    "\n",
    "<img src=\"./workflow.png\" width = \"60%\" height = \"60%\" align=center />"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c0188bf",
   "metadata": {},
   "source": [
    "### Load question embedding into Milvus"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a654fdc",
   "metadata": {},
   "source": [
    "We first generate embedding from question text with [dpr](https://towhee.io/text-embedding/dpr) operator and insert the embedding into Milvus. Towhee provides a [method-chaining style API](https://towhee.readthedocs.io/en/main/index.html) so that users can assemble a data processing pipeline with operators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "13b7beea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2min 37s, sys: 3min 59s, total: 6min 37s\n",
      "Wall time: 1min 27s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from towhee import pipe, ops\n",
    "import numpy as np\n",
    "from towhee.datacollection import DataCollection\n",
    "\n",
    "insert_pipe = (\n",
    "    pipe.input('id', 'question', 'answer')\n",
    "        .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))\n",
    "        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))\n",
    "        .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer'))\n",
    "        .output()\n",
    ")\n",
    "\n",
    "import csv\n",
    "with open('question_answer.csv', encoding='utf-8') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader)\n",
    "    for row in reader:\n",
    "        insert_pipe(*row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1adbb2e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of inserted data is 1000.\n"
     ]
    }
   ],
   "source": [
    "print('Total number of inserted data is {}.'.format(collection.num_entities))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deb269f4",
   "metadata": {},
   "source": [
    "#### Explanation of Data Processing Pipeline\n",
    "\n",
    "Here is detailed explanation for each line of the code:\n",
    "\n",
    "`pipe.input('id', 'question', 'answer')`: Get three inputs, namely question's id, quesion's text and question's answer;\n",
    "\n",
    "`map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))`: Use the `acebook/dpr-ctx_encoder-single-nq-base` model to generate the question embedding vector with the [dpr operator](https://towhee.io/text-embedding/dpr) in towhee hub;\n",
    "\n",
    "`map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))`: normalize the embedding vector;\n",
    "\n",
    "`map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer'))`: insert question embedding vector into Milvus;"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b35657d0",
   "metadata": {},
   "source": [
    "### Ask Question with Milvus and Towhee"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd02adfc",
   "metadata": {},
   "source": [
    "Now that embedding for question dataset have been inserted into Milvus, we can ask question with Milvus and Towhee. Again, we use Towhee to load the input question, compute a embedding, and use it as a query in Milvus. Because Milvus only outputs IDs and distance values, we provide the `id_answers` dictionary to get the answers based on IDs and display."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95913f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">question</th> <th style=\"text-align: center; font-size: 130%; border: none;\">answer</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">Is  Disability  Insurance  Required  By  Law?</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>Not generally. There are five states that require most all employers carry short term disability insurance on their employees. T...</br></td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.12 s, sys: 375 ms, total: 1.49 s\n",
      "Wall time: 16.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "collection.load()\n",
    "ans_pipe = (\n",
    "    pipe.input('question')\n",
    "        .map('question', 'vec', ops.text_embedding.dpr(model_name=\"facebook/dpr-ctx_encoder-single-nq-base\"))\n",
    "        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))\n",
    "        .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))\n",
    "        .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])\n",
    "        .output('question', 'answer')\n",
    ")\n",
    "\n",
    "\n",
    "ans = ans_pipe('Is  Disability  Insurance  Required  By  Law?')\n",
    "ans = DataCollection(ans)\n",
    "ans.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfb05a79",
   "metadata": {},
   "source": [
    "Then we can get the answer about 'Is  Disability  Insurance  Required  By  Law?'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cb1a8f96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Not generally. There are five states that require most all employers carry short term disability insurance on their employees. These states are: California, Hawaii, New Jersey, New York, and Rhode Island. Besides this mandatory short term disability law, there is no other legislative imperative for someone to purchase or be covered by disability insurance.']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ans[0]['answer']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01bef722",
   "metadata": {},
   "source": [
    "## Release a Showcase"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c71cace8",
   "metadata": {},
   "source": [
    "We've done an excellent job on the core functionality of our question answering engine. Now it's time to build a showcase with interface. [Gradio](https://gradio.app/) is a great tool for building demos. With Gradio, we simply need to wrap the data processing pipeline via a `chat` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "65d42114",
   "metadata": {},
   "outputs": [],
   "source": [
    "import towhee\n",
    "def chat(message, history):\n",
    "    history = history or []\n",
    "    ans_pipe = (\n",
    "        pipe.input('question')\n",
    "            .map('question', 'vec', ops.text_embedding.dpr(model_name=\"facebook/dpr-ctx_encoder-single-nq-base\"))\n",
    "            .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))\n",
    "            .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))\n",
    "            .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])\n",
    "            .output('question', 'answer')\n",
    "    )\n",
    "\n",
    "    response = ans_pipe(message).get()[1][0]\n",
    "    history.append((message, response))\n",
    "    return history, history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "065523a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "Running on public URL: https://7efbf90b-a281-48f9.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://7efbf90b-a281-48f9.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio\n",
    "\n",
    "collection.load()\n",
    "chatbot = gradio.Chatbot(color_map=(\"green\", \"gray\"))\n",
    "interface = gradio.Interface(\n",
    "    chat,\n",
    "    [\"text\", \"state\"],\n",
    "    [chatbot, \"state\"],\n",
    "    allow_screenshot=False,\n",
    "    allow_flagging=\"never\",\n",
    ")\n",
    "interface.launch(inline=True, share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23806967",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "f7dd10cdbe9a9c71f7e71741efd428241b5f9fa0fecdd29ae07a5706cd5ff8a2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
